/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for warp-level GEMM
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_types.h"
#include "cutlass/subbyte_reference.h"
#include "cutlass/platform/platform.h"
#include "cutlass/arch/arch.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/host_reorder.h"
#include "cutlass/util/host_uncompress.h"

namespace test {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Test kernel: thread 0 stages the A and B operand tiles in shared memory,
/// then the full warp computes the matrix product with the warp-level MMA
/// and stores the accumulator tile.
template <typename Mma, typename ThreadblockShape>
__global__ void kernel(typename Mma::ElementC* output_C,
                       typename Mma::ElementA const* input_A,
                       typename Mma::ElementB const* input_B,
                       typename Mma::ElementC const* input_C,
                       int iterations = 1) {
    // Use AlignedBuffer to store trivially copyable objects in unions and
    // __shared__ buffers.
    __shared__ cutlass::AlignedBuffer<
            typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK>
            smem_buffer_A;

    __shared__ cutlass::AlignedBuffer<
            typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK>
            smem_buffer_B;

    if (threadIdx.x == 0) {
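        // Thread 0 copies the operand tiles from global to shared memory one
        // element at a time. ReferenceFactory abstracts element access so that
        // sub-byte types, which cannot be addressed through ordinary pointers,
        // are handled uniformly.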
        typename Mma::ElementA* smem_ptr_A = smem_buffer_A.data();
#pragma unroll 1
        for (int i = 0; i < smem_buffer_A.size(); ++i) {
            cutlass::ReferenceFactory<typename Mma::ElementA>::get(smem_ptr_A,
                                                                   i) =
                    cutlass::ReferenceFactory<
                            typename cutlass::platform::remove_const<
                                    typename Mma::ElementA>::type>::get(input_A,
                                                                        i);
        }

        typename Mma::ElementB* smem_ptr_B = smem_buffer_B.data();
#pragma unroll 1
        for (int i = 0; i < smem_buffer_B.size(); ++i) {
            cutlass::ReferenceFactory<typename Mma::ElementB>::get(smem_ptr_B,
                                                                   i) =
                    cutlass::ReferenceFactory<
                            typename cutlass::platform::remove_const<
                                    typename Mma::ElementB>::type>::get(input_B,
                                                                        i);
        }
    }

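    // Make the tiles staged by thread 0 visible to every lane in the warp
    // before the iterators start loading fragments.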
    __syncthreads();

    //
    // Construct warp-level matrix product
    //

    using FragmentA = typename Mma::FragmentA;
    using FragmentB = typename Mma::FragmentB;
    using FragmentC = typename Mma::FragmentC;

    typename Mma::LayoutA layout_A =
            Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK});
    typename Mma::LayoutB layout_B =
            Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN});
    typename Mma::LayoutC layout_C =
            Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN});

    typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A},
                                   cutlass::arch::LaneId());

    typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B},
                                   cutlass::arch::LaneId());

    FragmentA frag_A;
    FragmentB frag_B;

    FragmentC accum;

    Mma mma;

    accum.clear();

    CUTLASS_PRAGMA_NO_UNROLL
    for (int iter = 0; iter < iterations;
         ++iter) {  // outer loop is deliberately not unrolled

        CUTLASS_PRAGMA_UNROLL
        for (int k = 0; k < ThreadblockShape::kK;
             k += Mma::Policy::MmaShape::kK) {
            iter_A.load(frag_A);
            iter_B.load(frag_B);

            ++iter_A;
            ++iter_B;

            mma(accum, frag_A, frag_B, accum);
        }
    }

    typename Mma::IteratorC iter_C({output_C, layout_C},
                                   cutlass::arch::LaneId());

    iter_C.store(accum);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Testbed that verifies the warp-level matrix product against a host
/// reference GEMM
template <
        /// Warp-level matrix multiply-accumulate
        typename Mma_,
        /// Threadblock-scoped tile shape that sizes the shared-memory buffers
        typename ThreadblockShape_,
        /// The inner product operation performed by GEMM
        typename Operator_ = cutlass::arch::OpMultiplyAdd>
struct Testbed {
    /// Warp-level matrix multiply-accumulate operator
    using Mma = Mma_;
    using ThreadblockShape = ThreadblockShape_;
    using Operator = Operator_;

    using Shape = typename Mma::Shape;
    using ElementA = typename Mma::ElementA;
    using LayoutA = typename Mma::LayoutA;
    using ElementB = typename Mma::ElementB;
    using LayoutB = typename Mma::LayoutB;
    using ElementC = typename Mma::ElementC;
    using LayoutC = typename Mma::LayoutC;

    //
    // Data members
    //

    cutlass::HostTensor<ElementA, LayoutA> tensor_A;
    cutlass::HostTensor<ElementB, LayoutB> tensor_B;
    cutlass::HostTensor<ElementC, LayoutC> tensor_C;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;

    //
    // Methods
    //

    /// Allocates workspace in host and device memory
    Testbed() {
        tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM,
                                           ThreadblockShape::kK));
        tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK,
                                           ThreadblockShape::kN));
        tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN),
                                 false);
    }

    /// Runs the test
    bool
    run(cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform,
        cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) {
        //
        // Initialize tensors in host memory and copy them to the device
        //

        if (init_A == cutlass::Distribution::Uniform) {
            int scope_max = 8;
            int scope_min = -8;

            if (cutlass::sizeof_bits<ElementA>::value == 4) {
                scope_max = 2;
                scope_min = -2;
            } else if (cutlass::sizeof_bits<ElementA>::value == 1) {
                scope_max = 2;
                scope_min = 0;
            }

            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_A.host_view(), seed, scope_max, scope_min, 0);
        } else if (init_A == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_A.host_data(),
                                                          tensor_A.capacity());
        } else if (init_A == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_A.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        if (init_B == cutlass::Distribution::Uniform) {
            int scope_max = 8;
            int scope_min = -8;

            if (cutlass::sizeof_bits<ElementB>::value == 4) {
                scope_max = 2;
                scope_min = -2;
            } else if (cutlass::sizeof_bits<ElementB>::value == 1) {
                scope_max = 2;
                scope_min = 0;
            }

            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_B.host_view(), seed + 16, scope_max, scope_min, 0);
        } else if (init_B == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_B.host_data(),
                                                          tensor_B.capacity());
        } else if (init_B == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_B.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        cutlass::reference::host::TensorFill(tensor_C.host_view(), ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_computed.host_view(),
                                             ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_reference.host_view(),
                                             ElementC(0));

        tensor_A.sync_device();
        tensor_B.sync_device();
        tensor_C.sync_device();
        tensor_D_computed.sync_device();

        // Launch the test kernel with a single warp of 32 threads
        kernel<Mma, ThreadblockShape><<<dim3(1, 1), dim3(32, 1, 1)>>>(
                tensor_D_computed.device_data(), tensor_A.device_data(),
                tensor_B.device_data(), tensor_C.device_data());

        // verify no errors
        cudaError_t result = cudaDeviceSynchronize();

        EXPECT_EQ(result, cudaSuccess)
                << "CUDA ERROR: " << cudaGetErrorString(result);
        if (result != cudaSuccess) {
            return false;
        }

        tensor_D_computed.sync_host();

        //
        // Reference implementation
        //

        cutlass::reference::host::Gemm<ElementA, LayoutA, ElementB, LayoutB,
                                       ElementC, LayoutC, ElementC, ElementC,
                                       Operator>
                reference_gemm;

        reference_gemm({Shape::kM, Shape::kN, ThreadblockShape::kK},
                       ElementC(1), tensor_A.host_ref(), tensor_B.host_ref(),
                       ElementC(0), tensor_D_reference.host_ref());

        //
        // Verify equivalence
        //

        // compare
        bool passed = cutlass::reference::host::TensorEquals(
                tensor_D_computed.host_view(), tensor_D_reference.host_view());

        EXPECT_TRUE(passed);

        if (!passed) {
            cutlass::TensorView<ElementA, cutlass::layout::ColumnMajor>
                    tensor_A_physical(tensor_A.host_data(),
                                      tensor_A.stride()[0], tensor_A.extent());

            cutlass::TensorView<ElementB, cutlass::layout::RowMajor>
                    tensor_B_physical(tensor_B.host_data(),
                                      tensor_B.stride()[0], tensor_B.extent());

            std::cout << "cutlass::sizeof_bits<ElementA>::value = "
                      << cutlass::sizeof_bits<ElementA>::value << "\n";
            std::cout << "A:\n"
                      << tensor_A.host_view() << "\n\n"
                      << "A(physical - stride: " << tensor_A.stride()[0]
                      << ", extent: " << tensor_A.extent() << "):\n"
                      << tensor_A_physical << "\n\n";

            std::cout << "cutlass::sizeof_bits<ElementB>::value = "
                      << cutlass::sizeof_bits<ElementB>::value << "\n";
            std::cout << "B:\n"
                      << tensor_B.host_view() << "\n\n"
                      << "B(physical - stride: " << tensor_B.stride()[0]
                      << ", extent: " << tensor_B.extent() << "):\n"
                      << tensor_B_physical << "\n\n";

            std::cout << "C:\n"
                      << tensor_C.host_view() << "\n\n"
                      << "Reference:\n"
                      << tensor_D_reference.host_view() << "\n\n"
                      << "Computed:\n"
                      << tensor_D_computed.host_view() << std::endl;
        }

        return passed;
    }
};
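
/// Illustrative usage, as a minimal sketch. `MmaTensorOp` is a placeholder
/// for a warp-level MMA assembled elsewhere (for example with
/// cutlass::gemm::warp::DefaultMmaTensorOp); it is not defined in this file.
///
///   using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>;
///
///   test::gemm::warp::Testbed<MmaTensorOp, ThreadblockShape> testbed;
///   EXPECT_TRUE(testbed.run());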

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Testbed that verifies the complex-valued warp-level matrix product against
/// a host reference GEMM
template <
        /// Warp-level matrix multiply-accumulate
        typename Mma_,
        /// Threadblock-scoped tile shape that sizes the shared-memory buffers
        typename ThreadblockShape_>
struct TestbedComplex {
    /// Warp-level matrix multiply-accumulate operator
    using Mma = Mma_;
    using ThreadblockShape = ThreadblockShape_;

    using Shape = typename Mma::Shape;
    using ElementA = typename Mma::ElementA;
    using LayoutA = typename Mma::LayoutA;
    using ElementB = typename Mma::ElementB;
    using LayoutB = typename Mma::LayoutB;
    using ElementC = typename Mma::ElementC;
    using LayoutC = typename Mma::LayoutC;

    //
    // Data members
    //

    cutlass::HostTensor<ElementA, LayoutA> tensor_A;
    cutlass::HostTensor<ElementB, LayoutB> tensor_B;
    cutlass::HostTensor<ElementC, LayoutC> tensor_C;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;

    //
    // Methods
    //

    /// Allocates workspace in host and device memory
    TestbedComplex() {
        tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM,
                                           ThreadblockShape::kK));
        tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK,
                                           ThreadblockShape::kN));
        tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN),
                                 false);
    }

    /// Runs the test
    bool
    run(cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform,
        cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) {
        //
        // Initialize tensors in host memory and copy them to the device
        //

        if (init_A == cutlass::Distribution::Uniform) {
            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_A.host_view(), seed, 8, -8, 0);
        } else if (init_A == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_A.host_data(),
                                                          tensor_A.capacity());
        } else if (init_A == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_A.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        if (init_B == cutlass::Distribution::Uniform) {
            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_B.host_view(), seed + 16, 8, -8, 0);
        } else if (init_B == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_B.host_data(),
                                                          tensor_B.capacity());
        } else if (init_B == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_B.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        cutlass::reference::host::TensorFill(tensor_C.host_view(), ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_computed.host_view(),
                                             ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_reference.host_view(),
                                             ElementC(0));

        tensor_A.sync_device();
        tensor_B.sync_device();
        tensor_C.sync_device();
        tensor_D_computed.sync_device();

        // Launch the test kernel with a single warp of 32 threads
        kernel<Mma, ThreadblockShape><<<dim3(1, 1), dim3(32, 1, 1)>>>(
                tensor_D_computed.device_data(), tensor_A.device_data(),
                tensor_B.device_data(), tensor_C.device_data());

        // verify no errors
        cudaError_t result = cudaDeviceSynchronize();

        EXPECT_EQ(result, cudaSuccess)
                << "CUDA ERROR: " << cudaGetErrorString(result);
        if (result != cudaSuccess) {
            return false;
        }

        tensor_D_computed.sync_host();

        //
        // Reference implementation
        //

        cutlass::reference::host::GemmComplex(
                {Shape::kM, Shape::kN, ThreadblockShape::kK}, ElementC(1),
                tensor_A.host_ref(), Mma::kTransformA, tensor_B.host_ref(),
                Mma::kTransformB, ElementC(0), tensor_C.host_ref(),
                tensor_D_reference.host_ref());

        //
        // Verify equivalence
        //

        // compare
        bool passed = cutlass::reference::host::TensorEquals(
                tensor_D_computed.host_view(), tensor_D_reference.host_view());

        EXPECT_TRUE(passed);

        if (!passed) {
            cutlass::TensorView<ElementA, cutlass::layout::ColumnMajor>
                    tensor_A_physical(tensor_A.host_data(),
                                      tensor_A.stride()[0], tensor_A.extent());

            cutlass::TensorView<ElementB, cutlass::layout::RowMajor>
                    tensor_B_physical(tensor_B.host_data(),
                                      tensor_B.stride()[0], tensor_B.extent());

            std::cout << "cutlass::sizeof_bits<ElementA>::value = "
                      << cutlass::sizeof_bits<ElementA>::value << "\n";
            std::cout << "A:\n"
                      << tensor_A.host_view() << "\n\n"
                      << "A(physical - stride: " << tensor_A.stride()[0]
                      << ", extent: " << tensor_A.extent() << "):\n"
                      << tensor_A_physical << "\n\n";

            std::cout << "cutlass::sizeof_bits<ElementB>::value = "
                      << cutlass::sizeof_bits<ElementB>::value << "\n";
            std::cout << "B:\n"
                      << tensor_B.host_view() << "\n\n"
                      << "B(physical - stride: " << tensor_B.stride()[0]
                      << ", extent: " << tensor_B.extent() << "):\n"
                      << tensor_B_physical << "\n\n";

            std::cout << "C:\n"
                      << tensor_C.host_view() << "\n\n"
                      << "Reference:\n"
                      << tensor_D_reference.host_view() << "\n\n"
                      << "Computed:\n"
                      << tensor_D_computed.host_view() << std::endl;
        }

        return passed;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Test kernel for warp-level MMAs whose operands require an explicit
/// transformation (Mma::transform) before the multiply-accumulate is issued.
template <typename Mma, typename ThreadblockShape>
__global__ void kernel_transform(typename Mma::ElementC* output_C,
                                 typename Mma::ElementA const* input_A,
                                 typename Mma::ElementB const* input_B,
                                 typename Mma::ElementC const* input_C,
                                 int iterations = 1) {
    // Use AlignedBuffer to store trivially copyable objects in unions and
    // __shared__ buffers.
    __shared__ cutlass::AlignedBuffer<
            typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK>
            smem_buffer_A;

    __shared__ cutlass::AlignedBuffer<
            typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK>
            smem_buffer_B;

    if (threadIdx.x == 0) {
        typename Mma::ElementA* smem_ptr_A = smem_buffer_A.data();
#pragma unroll 1
        for (int i = 0; i < smem_buffer_A.size(); ++i) {
            cutlass::ReferenceFactory<typename Mma::ElementA>::get(smem_ptr_A,
                                                                   i) =
                    cutlass::ReferenceFactory<
                            typename cutlass::platform::remove_const<
                                    typename Mma::ElementA>::type>::get(input_A,
                                                                        i);
        }

        typename Mma::ElementB* smem_ptr_B = smem_buffer_B.data();
#pragma unroll 1
        for (int i = 0; i < smem_buffer_B.size(); ++i) {
            cutlass::ReferenceFactory<typename Mma::ElementB>::get(smem_ptr_B,
                                                                   i) =
                    cutlass::ReferenceFactory<
                            typename cutlass::platform::remove_const<
                                    typename Mma::ElementB>::type>::get(input_B,
                                                                        i);
        }
    }

    __syncthreads();

    //
    // Construct warp-level matrix product
    //

    using FragmentA = typename Mma::FragmentA;
    using FragmentB = typename Mma::FragmentB;
    using FragmentC = typename Mma::FragmentC;

    using TransformedFragmentA = typename Mma::TransformedFragmentA;
    using TransformedFragmentB = typename Mma::TransformedFragmentB;

    typename Mma::LayoutA layout_A =
            Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK});
    typename Mma::LayoutB layout_B =
            Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN});
    typename Mma::LayoutC layout_C =
            Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN});

    typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A},
                                   cutlass::arch::LaneId());

    typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B},
                                   cutlass::arch::LaneId());

    FragmentA loaded_frag_A;
    FragmentB loaded_frag_B;
    TransformedFragmentA transformed_frag_A;
    TransformedFragmentB transformed_frag_B;

    FragmentC accum;

    Mma mma;

    accum.clear();

    CUTLASS_PRAGMA_NO_UNROLL
    for (int iter = 0; iter < iterations;
         ++iter) {  // outer loop is deliberately not unrolled

        CUTLASS_PRAGMA_UNROLL
        for (int k = 0; k < ThreadblockShape::kK;
             k += Mma::Policy::MmaShape::kK) {
            iter_A.load(loaded_frag_A);
            iter_B.load(loaded_frag_B);

            ++iter_A;
            ++iter_B;

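            // Convert the loaded fragments into the representation the MMA
            // instruction consumes (for example, rounding float operands to
            // tf32) before issuing the multiply-accumulate.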
            mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A,
                          loaded_frag_B);

            mma(accum, transformed_frag_A, transformed_frag_B, accum);
        }
    }

    typename Mma::IteratorC iter_C({output_C, layout_C},
                                   cutlass::arch::LaneId());

    iter_C.store(accum);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Testbed for warp-level MMAs that require an explicit operand
/// transformation, verified against a host reference GEMM
template <
        /// Warp-level matrix multiply-accumulate
        typename Mma_,
        /// Threadblock-scoped tile shape that sizes the shared-memory buffers
        typename ThreadblockShape_,
        /// The inner product operation performed by GEMM
        typename Operator_ = cutlass::arch::OpMultiplyAdd>
struct TransformTestbed {
    /// Warp-level matrix multiply-accumulate operator
    using Mma = Mma_;
    using ThreadblockShape = ThreadblockShape_;
    using Operator = Operator_;

    using Shape = typename Mma::Shape;
    using ElementA = typename Mma::ElementA;
    using LayoutA = typename Mma::LayoutA;
    using ElementB = typename Mma::ElementB;
    using LayoutB = typename Mma::LayoutB;
    using ElementC = typename Mma::ElementC;
    using LayoutC = typename Mma::LayoutC;

    //
    // Data members
    //

    cutlass::HostTensor<ElementA, LayoutA> tensor_A;
    cutlass::HostTensor<ElementB, LayoutB> tensor_B;
    cutlass::HostTensor<ElementC, LayoutC> tensor_C;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;

    //
    // Methods
    //

    /// Allocates workspace in host and device memory
    TransformTestbed() {
        tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM,
                                           ThreadblockShape::kK));
        tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK,
                                           ThreadblockShape::kN));
        tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN),
                                 false);
    }

    /// Runs the test
    bool
    run(cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform,
        cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) {
        //
        // Initialize tensors in host memory and copy them to the device
        //

        if (init_A == cutlass::Distribution::Uniform) {
            int scope_max = 8;
            int scope_min = -8;

            if (cutlass::sizeof_bits<ElementA>::value == 4) {
                scope_max = 2;
                scope_min = -2;
            } else if (cutlass::sizeof_bits<ElementA>::value == 1) {
                scope_max = 2;
                scope_min = 0;
            }

            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_A.host_view(), seed, scope_max, scope_min, 0);
        } else if (init_A == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_A.host_data(),
                                                          tensor_A.capacity());
        } else if (init_A == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_A.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        if (init_B == cutlass::Distribution::Uniform) {
            int scope_max = 8;
            int scope_min = -8;

            if (cutlass::sizeof_bits<ElementB>::value == 4) {
                scope_max = 2;
                scope_min = -2;
            } else if (cutlass::sizeof_bits<ElementB>::value == 1) {
                scope_max = 2;
                scope_min = 0;
            }

            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_B.host_view(), seed + 16, scope_max, scope_min, 0);
        } else if (init_B == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_B.host_data(),
                                                          tensor_B.capacity());
        } else if (init_B == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_B.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        cutlass::reference::host::TensorFill(tensor_C.host_view(), ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_computed.host_view(),
                                             ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_reference.host_view(),
                                             ElementC(0));

        tensor_A.sync_device();
        tensor_B.sync_device();
        tensor_C.sync_device();
        tensor_D_computed.sync_device();

        // Launch the test kernel with a single warp of 32 threads
        kernel_transform<Mma, ThreadblockShape><<<dim3(1, 1), dim3(32, 1, 1)>>>(
                tensor_D_computed.device_data(), tensor_A.device_data(),
                tensor_B.device_data(), tensor_C.device_data());

        // verify no errors
        cudaError_t result = cudaDeviceSynchronize();

        EXPECT_EQ(result, cudaSuccess)
                << "CUDA ERROR: " << cudaGetErrorString(result);
        if (result != cudaSuccess) {
            return false;
        }

        tensor_D_computed.sync_host();

        //
        // Reference implementation
        //

        cutlass::reference::host::Gemm<ElementA, LayoutA, ElementB, LayoutB,
                                       ElementC, LayoutC, ElementC, ElementC,
                                       Operator>
                reference_gemm;

        reference_gemm({Shape::kM, Shape::kN, ThreadblockShape::kK},
                       ElementC(1), tensor_A.host_ref(), tensor_B.host_ref(),
                       ElementC(0), tensor_D_reference.host_ref());

        //
        // Verify equivalence
        //

        // compare
        bool passed = cutlass::reference::host::TensorEquals(
                tensor_D_computed.host_view(), tensor_D_reference.host_view());

        EXPECT_TRUE(passed);

        if (!passed) {
            cutlass::TensorView<ElementA, cutlass::layout::ColumnMajor>
                    tensor_A_physical(tensor_A.host_data(),
                                      tensor_A.stride()[0], tensor_A.extent());

            cutlass::TensorView<ElementB, cutlass::layout::RowMajor>
                    tensor_B_physical(tensor_B.host_data(),
                                      tensor_B.stride()[0], tensor_B.extent());

            std::cout << "cutlass::sizeof_bits<ElementA>::value = "
                      << cutlass::sizeof_bits<ElementA>::value << "\n";
            std::cout << "A:\n"
                      << tensor_A.host_view() << "\n\n"
                      << "A(physical - stride: " << tensor_A.stride()[0]
                      << ", extent: " << tensor_A.extent() << "):\n"
                      << tensor_A_physical << "\n\n";

            std::cout << "cutlass::sizeof_bits<ElementB>::value = "
                      << cutlass::sizeof_bits<ElementB>::value << "\n";
            std::cout << "B:\n"
                      << tensor_B.host_view() << "\n\n"
                      << "B(physical - stride: " << tensor_B.stride()[0]
                      << ", extent: " << tensor_B.extent() << "):\n"
                      << tensor_B_physical << "\n\n";

            std::cout << "C:\n"
                      << tensor_C.host_view() << "\n\n"
                      << "Reference:\n"
                      << tensor_D_reference.host_view() << "\n\n"
                      << "Computed:\n"
                      << tensor_D_computed.host_view() << std::endl;
        }

        return passed;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Testbed for complex-valued warp-level MMAs that require an explicit
/// operand transformation, verified against a host reference GEMM
template <
        /// Warp-level matrix multiply-accumulate
        typename Mma_,
        /// Threadblock-scoped tile shape that sizes the shared-memory buffers
        typename ThreadblockShape_>
struct TransformedTestbedComplex {
    /// Warp-level matrix multiply-accumulate operator
    using Mma = Mma_;
    using ThreadblockShape = ThreadblockShape_;

    using Shape = typename Mma::Shape;
    using ElementA = typename Mma::ElementA;
    using LayoutA = typename Mma::LayoutA;
    using ElementB = typename Mma::ElementB;
    using LayoutB = typename Mma::LayoutB;
    using ElementC = typename Mma::ElementC;
    using LayoutC = typename Mma::LayoutC;

    //
    // Data members
    //

    cutlass::HostTensor<ElementA, LayoutA> tensor_A;
    cutlass::HostTensor<ElementB, LayoutB> tensor_B;
    cutlass::HostTensor<ElementC, LayoutC> tensor_C;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;

    //
    // Methods
    //

    /// Allocates workspace in host and device memory
    TransformedTestbedComplex() {
        tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM,
                                           ThreadblockShape::kK));
        tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK,
                                           ThreadblockShape::kN));
        tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN),
                                 false);
    }

    /// Runs the test
    bool
    run(cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform,
        cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) {
        //
        // Initialize tensors in host memory and copy them to the device
        //

        if (init_A == cutlass::Distribution::Uniform) {
            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_A.host_view(), seed, 8, -8, 0);
        } else if (init_A == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_A.host_data(),
                                                          tensor_A.capacity());
        } else if (init_A == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_A.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        if (init_B == cutlass::Distribution::Uniform) {
            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_B.host_view(), seed + 16, 8, -8, 0);
        } else if (init_B == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_B.host_data(),
                                                          tensor_B.capacity());
        } else if (init_B == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_B.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        cutlass::reference::host::TensorFill(tensor_C.host_view(), ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_computed.host_view(),
                                             ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_reference.host_view(),
                                             ElementC(0));

        tensor_A.sync_device();
        tensor_B.sync_device();
        tensor_C.sync_device();
        tensor_D_computed.sync_device();

        // Launch the test kernel with a single warp of 32 threads
        kernel_transform<Mma, ThreadblockShape><<<dim3(1, 1), dim3(32, 1, 1)>>>(
                tensor_D_computed.device_data(), tensor_A.device_data(),
                tensor_B.device_data(), tensor_C.device_data());

        // verify no errors
        cudaError_t result = cudaDeviceSynchronize();

        EXPECT_EQ(result, cudaSuccess)
                << "CUDA ERROR: " << cudaGetErrorString(result);
        if (result != cudaSuccess) {
            return false;
        }

        tensor_D_computed.sync_host();

        //
        // Reference implementation
        //

        cutlass::reference::host::GemmComplex(
                {Shape::kM, Shape::kN, ThreadblockShape::kK}, ElementC(1),
                tensor_A.host_ref(), Mma::kTransformA, tensor_B.host_ref(),
                Mma::kTransformB, ElementC(0), tensor_C.host_ref(),
                tensor_D_reference.host_ref());

        //
        // Verify equivalence
        //

        // compare
        bool passed = cutlass::reference::host::TensorEquals(
                tensor_D_computed.host_view(), tensor_D_reference.host_view());

        EXPECT_TRUE(passed);

        if (!passed) {
            cutlass::TensorView<ElementA, cutlass::layout::ColumnMajor>
                    tensor_A_physical(tensor_A.host_data(),
                                      tensor_A.stride()[0], tensor_A.extent());

            cutlass::TensorView<ElementB, cutlass::layout::RowMajor>
                    tensor_B_physical(tensor_B.host_data(),
                                      tensor_B.stride()[0], tensor_B.extent());

            std::cout << "cutlass::sizeof_bits<ElementA>::value = "
                      << cutlass::sizeof_bits<ElementA>::value << "\n";
            std::cout << "A:\n"
                      << tensor_A.host_view() << "\n\n"
                      << "A(physical - stride: " << tensor_A.stride()[0]
                      << ", extent: " << tensor_A.extent() << "):\n"
                      << tensor_A_physical << "\n\n";

            std::cout << "cutlass::sizeof_bits<ElementB>::value = "
                      << cutlass::sizeof_bits<ElementB>::value << "\n";
            std::cout << "B:\n"
                      << tensor_B.host_view() << "\n\n"
                      << "B(physical - stride: " << tensor_B.stride()[0]
                      << ", extent: " << tensor_B.extent() << "):\n"
                      << tensor_B_physical << "\n\n";

            std::cout << "C:\n"
                      << tensor_C.host_view() << "\n\n"
                      << "Reference:\n"
                      << tensor_D_reference.host_view() << "\n\n"
                      << "Computed:\n"
                      << tensor_D_computed.host_view() << std::endl;
        }

        return passed;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Test kernel for the sparse warp-level MMA: A is stored in compressed form
/// and the metadata tensor E selects which of its elements participate.
template <typename Mma, typename ThreadblockShape>
__global__ void sparse_kernel(typename Mma::ElementC* output_C,
                              typename Mma::ElementA const* input_A,
                              typename Mma::ElementB const* input_B,
                              typename Mma::ElementC const* input_C,
                              typename Mma::ElementE const* input_E,
                              int iterations = 1) {
    // Use AlignedBuffer to store trivially copyable objects in unions and
    // __shared__ buffers.
    __shared__ cutlass::AlignedBuffer<
            typename Mma::ElementA,
            ThreadblockShape::kM * ThreadblockShape::kK / Mma::kSparse>
            smem_buffer_A;

    __shared__ cutlass::AlignedBuffer<
            typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK>
            smem_buffer_B;

    __shared__ cutlass::AlignedBuffer<typename Mma::ElementE,
                                      Mma::Shape::kM * Mma::Shape::kK /
                                              Mma::kSparse /
                                              Mma::kElementsPerElementE>
            smem_buffer_E;

    __syncthreads();

    if (threadIdx.x == 0) {
        typename Mma::ElementA* smem_ptr_A = smem_buffer_A.data();
#pragma unroll 1
        for (int i = 0; i < smem_buffer_A.size(); ++i) {
            cutlass::ReferenceFactory<typename Mma::ElementA>::get(smem_ptr_A,
                                                                   i) =
                    cutlass::ReferenceFactory<
                            typename cutlass::platform::remove_const<
                                    typename Mma::ElementA>::type>::get(input_A,
                                                                        i);
        }

        typename Mma::ElementB* smem_ptr_B = smem_buffer_B.data();
#pragma unroll 1
        for (int i = 0; i < smem_buffer_B.size(); ++i) {
            cutlass::ReferenceFactory<typename Mma::ElementB>::get(smem_ptr_B,
                                                                   i) =
                    cutlass::ReferenceFactory<
                            typename cutlass::platform::remove_const<
                                    typename Mma::ElementB>::type>::get(input_B,
                                                                        i);
        }

        typename Mma::ElementE* smem_ptr_E = smem_buffer_E.data();
#pragma unroll 1
        for (int i = 0; i < smem_buffer_E.size(); ++i) {
            cutlass::ReferenceFactory<typename Mma::ElementE>::get(smem_ptr_E,
                                                                   i) =
                    cutlass::ReferenceFactory<
                            typename cutlass::platform::remove_const<
                                    typename Mma::ElementE>::type>::get(input_E,
                                                                        i);
        }
    }

    __syncthreads();

    //
    // Construct warp-level matrix product
    //

    using FragmentA = typename Mma::FragmentA;
    using FragmentB = typename Mma::FragmentB;
    using FragmentC = typename Mma::FragmentC;
    using FragmentE = typename Mma::FragmentE;

    typename Mma::LayoutA layout_A = Mma::LayoutA::packed(
            {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse});
    typename Mma::LayoutB layout_B =
            Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN});
    typename Mma::LayoutC layout_C =
            Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN});
    typename Mma::LayoutE layout_E = Mma::LayoutE::packed(
            {Mma::Shape::kM * Mma::kInterleaved,
             Mma::Shape::kK / Mma::kSparse / Mma::kElementsPerElementE /
                     Mma::kInterleaved});

    typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A},
                                   cutlass::arch::LaneId());

    typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B},
                                   cutlass::arch::LaneId());

    typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E},
                                   cutlass::arch::LaneId());

    FragmentA frag_A;
    FragmentB frag_B;

    FragmentC accum;

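    // Fragment holding the sparsity metadata that selects which elements of
    // the compressed A operand participate in each MMA (e.g. 2:4 structured
    // sparsity).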
    FragmentE frag_E;

    Mma mma;

    accum.clear();

    CUTLASS_PRAGMA_NO_UNROLL
    for (int iter = 0; iter < iterations;
         ++iter) {  // outer loop is deliberately not unrolled

        CUTLASS_PRAGMA_UNROLL
        for (int k = 0; k < ThreadblockShape::kK;
             k += Mma::Policy::MmaShape::kK) {
            iter_A.load(frag_A);
            iter_B.load(frag_B);
            iter_E.load(frag_E);

            ++iter_A;
            ++iter_B;
            ++iter_E;

            mma(accum, frag_A, frag_B, accum, frag_E);
        }
    }

    typename Mma::IteratorC iter_C({output_C, layout_C},
                                   cutlass::arch::LaneId());

    iter_C.store(accum);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Testbed that verifies the sparse warp-level matrix product against a host
/// reference GEMM
template <
        /// Warp-level matrix multiply-accumulate
        typename Mma_,
        /// Threadblock-scoped tile shape that sizes the shared-memory buffers
        typename ThreadblockShape_,
        /// The inner product operation performed by GEMM
        typename Operator_ = cutlass::arch::OpMultiplyAdd>
struct SparseTestbed {
    /// Warp-level matrix multiply-accumulate operator
    using Mma = Mma_;
    using ThreadblockShape = ThreadblockShape_;
    using Operator = Operator_;

    using Shape = typename Mma::Shape;
    using ElementA = typename Mma::ElementA;
    using LayoutA = typename Mma::LayoutA;
    using ElementB = typename Mma::ElementB;
    using LayoutB = typename Mma::LayoutB;
    using ElementC = typename Mma::ElementC;
    using LayoutC = typename Mma::LayoutC;

    static int const Sparse = Mma::kSparse;
    static int const MetaSizeInBits = Mma::kMetaSizeInBits;
    static int const MaxID2 = Mma::kMaxID2;
    static int const Interleaved = Mma::kInterleaved;

    using ElementE = typename Mma::ElementE;

    static int const ElementsPerElementE = Mma::kElementsPerElementE;

    using LayoutE = cutlass::layout::RowMajor;
    using ReorderedLayoutE =
            cutlass::layout::ColumnMajorInterleaved<Interleaved>;

    //
    // Data members
    //

    cutlass::HostTensor<ElementA, LayoutA> tensor_A;
    cutlass::HostTensor<ElementA, LayoutA> tensor_A_uncompressed;
    cutlass::HostTensor<ElementB, LayoutB> tensor_B;
    cutlass::HostTensor<ElementC, LayoutC> tensor_C;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
    cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
    cutlass::HostTensor<ElementE, LayoutE> tensor_E;
    cutlass::HostTensor<ElementE, ReorderedLayoutE> tensor_E_reordered;

    //
    // Methods
    //

    /// Allocates workspace in host and device memory
    SparseTestbed() {
        tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM,
                                           ThreadblockShape::kK / Sparse));
        tensor_A_uncompressed.reset(cutlass::make_Coord(ThreadblockShape::kM,
                                                        ThreadblockShape::kK));
        tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK,
                                           ThreadblockShape::kN));
        tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN));
        tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN),
                                 false);
        tensor_E.reset(cutlass::make_Coord(
                Shape::kM, Shape::kK / Sparse / ElementsPerElementE));
        tensor_E_reordered.reset(cutlass::make_Coord(
                Shape::kM, Shape::kK / Sparse / ElementsPerElementE));
    }

    /// Runs the test
    bool
    run(cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform,
        cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform,
        cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) {
        //
        // Initialize tensors in host memory and copy them to the device
        //

        if (init_A == cutlass::Distribution::Uniform) {
            int scope_max = 8;
            int scope_min = -8;

            if (cutlass::sizeof_bits<ElementA>::value == 4) {
                scope_max = 2;
                scope_min = -2;
            } else if (cutlass::sizeof_bits<ElementA>::value == 1) {
                scope_max = 2;
                scope_min = 0;
            }

            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_A.host_view(), seed, scope_max, scope_min, 0);
        } else if (init_A == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_A.host_data(),
                                                          tensor_A.capacity());
        } else if (init_A == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_A.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        if (init_B == cutlass::Distribution::Uniform) {
            int scope_max = 8;
            int scope_min = -8;

            if (cutlass::sizeof_bits<ElementB>::value == 4) {
                scope_max = 2;
                scope_min = -2;
            } else if (cutlass::sizeof_bits<ElementB>::value == 1) {
                scope_max = 2;
                scope_min = 0;
            }

            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomUniform(
                    tensor_B.host_view(), seed + 16, scope_max, scope_min, 0);
        } else if (init_B == cutlass::Distribution::Sequential) {
            cutlass::reference::host::BlockFillSequential(tensor_B.host_data(),
                                                          tensor_B.capacity());
        } else if (init_B == cutlass::Distribution::Identity) {
            cutlass::reference::host::TensorFillIdentity(tensor_B.host_view());
        } else {
            // TODO: Implement the rest
            return false;
        }

        cutlass::reference::host::TensorFill(tensor_C.host_view(), ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_computed.host_view(),
                                             ElementC(0));

        cutlass::reference::host::TensorFill(tensor_D_reference.host_view(),
                                             ElementC(0));

        if (init_E == cutlass::Distribution::Uniform) {
            uint64_t seed = 7;
            cutlass::reference::host::TensorFillRandomSparseMeta(
                    tensor_E.host_view(), seed, MetaSizeInBits);
        } else if (init_E == cutlass::Distribution::Identity) {
            uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444;
            cutlass::reference::host::TensorFill(tensor_E.host_view(),
                                                 (ElementE)(content));
        } else {
            // TODO: Implement the rest
            return false;
        }

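        // Reorder the metadata into the interleaved layout expected by the
        // device-side metadata iterator.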
        cutlass::reorder_meta(tensor_E_reordered.host_ref(),
                              tensor_E.host_ref(),
                              {Shape::kM, Shape::kN,
                               Shape::kK / Sparse / ElementsPerElementE});

        tensor_A.sync_device();
        tensor_B.sync_device();
        tensor_C.sync_device();
        tensor_D_computed.sync_device();
        tensor_E_reordered.sync_device();

        // Launch the test kernel with a single warp of 32 threads
        sparse_kernel<Mma, ThreadblockShape><<<dim3(1, 1), dim3(32, 1, 1)>>>(
                tensor_D_computed.device_data(), tensor_A.device_data(),
                tensor_B.device_data(), tensor_C.device_data(),
                tensor_E_reordered.device_data());

        // verify no errors
        cudaError_t result = cudaDeviceSynchronize();

        EXPECT_EQ(result, cudaSuccess)
                << "CUDA ERROR: " << cudaGetErrorString(result);
        if (result != cudaSuccess) {
            return false;
        }

        tensor_D_computed.sync_host();

        //
        // Reference implementation
        //
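        // Expand the compressed A operand using the metadata in tensor_E so
        // that the dense host reference GEMM below can be reused.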
        cutlass::uncompress(tensor_A_uncompressed.host_ref(),
                            tensor_A.host_ref(), tensor_E.host_ref(), Shape::kM,
                            Shape::kK);

        cutlass::reference::host::Gemm<ElementA, LayoutA, ElementB, LayoutB,
                                       ElementC, LayoutC, ElementC, ElementC,
                                       Operator>
                reference_gemm;

        reference_gemm({Shape::kM, Shape::kN, ThreadblockShape::kK},
                       ElementC(1), tensor_A_uncompressed.host_ref(),
                       tensor_B.host_ref(), ElementC(0),
                       tensor_D_reference.host_ref());

        //
        // Verify equivalence
        //

        // compare
        bool passed = cutlass::reference::host::TensorEquals(
                tensor_D_computed.host_view(), tensor_D_reference.host_view());

        EXPECT_TRUE(passed);

        if (!passed) {
            std::cout << "cutlass::sizeof_bits<ElementA>::value = "
                      << cutlass::sizeof_bits<ElementA>::value << "\n";
            std::cout << "A:\n" << tensor_A.host_view() << "\n\n";

            std::cout << "cutlass::sizeof_bits<ElementB>::value = "
                      << cutlass::sizeof_bits<ElementB>::value << "\n";
            std::cout << "B:\n" << tensor_B.host_view() << "\n\n";

            std::cout << "cutlass::sizeof_bits<ElementB>::value = "
                      << cutlass::sizeof_bits<ElementE>::value << "\n";
            std::cout << "E:\n" << tensor_E.host_view() << "\n\n";

            std::cout << "C:\n"
                      << tensor_C.host_view() << "\n\n"
                      << "Reference:\n"
                      << tensor_D_reference.host_view() << "\n\n"
                      << "Computed:\n"
                      << tensor_D_computed.host_view() << "\n";
        }

        return passed;
    }
};
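
/// Illustrative usage, as a minimal sketch. `SparseMmaTensorOp` is a
/// placeholder for a warp-level sparse MMA assembled elsewhere (for example
/// with cutlass::gemm::warp::DefaultSparseMmaTensorOp); it is not defined in
/// this file.
///
///   using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
///
///   test::gemm::warp::SparseTestbed<SparseMmaTensorOp, ThreadblockShape>
///           testbed;
///   EXPECT_TRUE(testbed.run());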

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace test
